Adding UniSRec model implemented on lightweight class hierarchy with pytorch preprocessing #306

Open
TOPAPEC wants to merge 7 commits into MTSWebServices:main from TOPAPEC:feat/unisrec-model

Conversation


@TOPAPEC TOPAPEC commented Apr 24, 2026

New rectools.fast_transformers module — standalone transformer sequential recommenders that work with raw torch tensors, without going through Dataset/pandas.

  • GPU-native preprocessing. build_sequences() builds left-padded interaction sequences entirely in torch (argsort + scatter). On ML-20M (20M interactions) this takes 0.5s vs 14.6s for the pandas-based SASRecDataPreparator — roughly 30x faster. On larger production data the gap is even wider: preprocessing half a year of the KION production dataset takes up to 50 minutes with the current RecTools code, which is comparable to the training time itself.

  • FlatSASRec. Pre-norm SASRec encoder with plain id-embeddings, no ItemNet hierarchy. Wraps into FlatSASRecModel (inherits ModelBase) so it plugs into standard RecTools fit/recommend.

  • UniSRec. Three-phase sequential recommender with pretrained text embeddings and a learnable PCA adaptor:

    1. SASRec warm-up on ID embeddings (transformer + item_emb)
    2. Adaptor-only training (transformer frozen, pretrained embeddings)
    3. Full fine-tune (adaptor + transformer on pretrained embeddings)

    UniSRecModel.fit(user_ids, item_ids, timestamps) takes raw tensors end-to-end. Supports softmax/BCE/gBCE/sampled_softmax losses, Adam/AdamW, cosine warmup scheduler, gradient clipping, early stopping, checkpoint save/load. FFN blocks are configurable (conv1d, linear_gelu, linear_relu).

  • rank_topk() — batched top-k with CSR viewed-item filtering and whitelist support.
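The argsort + scatter idea behind build_sequences() can be sketched in a few lines of plain torch. This is a toy version for illustration only: the function name and return layout are mine, min_interactions filtering is omitted, and item ids are assumed to be >= 1 so that 0 can serve as the padding id.

```python
import torch

def build_left_padded(user_ids, item_ids, timestamps, max_len):
    # Sort by (user, timestamp): stable sort by timestamp, then stable sort by user.
    order = torch.argsort(timestamps, stable=True)
    order = order[torch.argsort(user_ids[order], stable=True)]
    u, items = user_ids[order], item_ids[order]

    uniq, inv, counts = torch.unique(u, return_inverse=True, return_counts=True)
    starts = torch.cumsum(counts, 0) - counts  # first row of each user's group
    pos = torch.arange(len(u)) - starts[inv]   # 0-based rank within the user
    keep = pos >= counts[inv] - max_len        # keep only the last max_len items

    # Rank r of a user with c interactions lands in column r + max_len - c,
    # so shorter histories end up right-aligned (left-padded with 0).
    cols = pos[keep] + max_len - counts[inv][keep]
    seqs = torch.zeros(len(uniq), max_len, dtype=items.dtype)
    seqs[inv[keep], cols] = items[keep]        # single scatter, no Python loop
    return uniq, seqs
```

Everything here operates on whole tensors at once, which is what makes the GPU path cheap compared to a pandas groupby.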

Benchmark (ML-20M, 10 epochs, softmax, Adam, n_factors=256)

| | SASRec | UniSRec ID |
|---|---|---|
| Preprocessing | 14.6s | 0.5s |
| Training | 911.8s | 639.5s |
| Evaluation (138K users) | 175.6s | 28.0s |
| Total | 1102s | 668s |

| | HR@10 | NDCG@10 | MRR@10 |
|---|---|---|---|
| SASRec | 0.2417 | 0.1410 | 0.1103 |
| UniSRec ID | 0.2528 | 0.1495 | 0.1179 |

UniSRec ID: +4.6% HR@10, +6.0% NDCG@10, 1.65x faster overall.

New files

Source (9 modules, 1683 lines):

  • rectools/fast_transformers/gpu_data.py — build_sequences, align_embeddings, GPUBatchDataset, make_dataloader
  • rectools/fast_transformers/net.py — FlatSASRec, SASRecBlock
  • rectools/fast_transformers/lightning_wrap.py — FlatSASRecLightning
  • rectools/fast_transformers/model.py — FlatSASRecModel, FlatSASRecConfig
  • rectools/fast_transformers/ranking.py — rank_topk
  • rectools/fast_transformers/unisrec_net.py — UniSRec, FeedForward, make_ffn
  • rectools/fast_transformers/unisrec_lightning.py — UniSRecLightning, loss/optimizer/scheduler dispatch
  • rectools/fast_transformers/unisrec_model.py — UniSRecModel (three-phase fit, checkpoint)

Tests (143 tests, 1920 lines):

  • tests/fast_transformers/test_gpu_data.py — sequence building, alignment, dataset/dataloader
  • tests/fast_transformers/test_net.py, test_lightning_wrap.py, test_model.py — FlatSASRec stack
  • tests/fast_transformers/test_unisrec_net.py, test_unisrec_lightning.py, test_unisrec_model.py — UniSRec stack
  • tests/fast_transformers/test_ranking.py — top-k, filtering, edge cases

Scripts:

  • scripts/compare_sasrec_unisrec.py — full benchmark with markdown report generation
  • scripts/comparison_report.md — benchmark results

Test plan

  • All 143 tests pass (pytest tests/fast_transformers/ -q)
  • Run on GPU to confirm CUDA path works
  • Verify FlatSASRecModel fit/recommend through the standard RecTools API on a small dataset

TOPAPEC and others added 4 commits April 22, 2026 18:28
Standalone sequential recommender package, mimics ModelBase interface
without touching existing rectools code.

FlatSASRec - plain ID-embedding SASRec encoder.
UniSRec - pretrained text embeddings + PCA/BN adaptor, 3-phase training
(ID emb -> adaptor only -> full finetune).

Uses lightweight rank_topk instead of TorchRanker, reuses
SASRecDataPreparator for the data pipeline.

30 tests, smoke scripts for both models.
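The 3-phase schedule (ID emb -> adaptor only -> full finetune) boils down to toggling requires_grad on three parameter groups between fit phases. A minimal sketch with stand-in modules (the class and attribute names here are illustrative, not the PR's actual UniSRec):

```python
import torch
from torch import nn

class ToyUniSRec(nn.Module):
    """Stand-in exposing the three parameter groups the phases toggle."""
    def __init__(self, n_items=100, text_dim=32, n_factors=16):
        super().__init__()
        self.item_emb = nn.Embedding(n_items + 1, n_factors, padding_idx=0)
        # Pretrained text embeddings are fixed inputs: a buffer, not a parameter.
        self.register_buffer("text_emb", torch.randn(n_items + 1, text_dim))
        self.adaptor = nn.Linear(text_dim, n_factors)   # stands in for the PCA/BN adaptor
        self.encoder = nn.Linear(n_factors, n_factors)  # stands in for the transformer

def set_phase(model: ToyUniSRec, phase: int) -> None:
    trainable = {
        1: {"item_emb": True, "adaptor": False, "encoder": True},   # SASRec warm-up
        2: {"item_emb": False, "adaptor": True, "encoder": False},  # adaptor only
        3: {"item_emb": False, "adaptor": True, "encoder": True},   # full finetune
    }[phase]
    for name, flag in trainable.items():
        for p in getattr(model, name).parameters():
            p.requires_grad_(flag)
```

Between phases the optimizer would be rebuilt over the parameters with requires_grad set, so frozen groups receive no updates.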

Fix: NaN*0=NaN in IEEE 754 breaks attention padding masking via
multiplication, switched to masked_fill.
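The failure mode is easy to reproduce with toy tensors (this illustrates the IEEE 754 issue, not the PR's actual attention code):

```python
import torch

# A padded position has produced a NaN score upstream.
scores = torch.tensor([1.0, float("nan"), 2.0])
pad = torch.tensor([False, True, False])  # True = padding position

# Masking by multiplication keeps the NaN, because NaN * 0 == NaN.
broken = scores * (~pad)  # still contains the NaN

# masked_fill overwrites the slot unconditionally, so softmax stays finite
# and the padded position gets exactly zero attention weight.
probs = torch.softmax(scores.masked_fill(pad, float("-inf")), dim=0)
```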
New config options:
- ffn_type: conv1d / linear_gelu / linear_relu + ffn_expansion
- optimizer: adam / adamw
- scheduler: cosine_warmup (with warmup_ratio, min_lr_ratio)
- loss: softmax / BCE / gBCE / sampled_softmax (with gbce_t)
- patience: early stopping via EarlyStopping callback + val split
- data_preparator: accept custom preparator instance

31 tests passing.
@TOPAPEC TOPAPEC changed the title Feat/unisrec model Adding UniSRec model implemented on lightweight class hierarchy with pytorch preprocessing Apr 24, 2026
TOPAPEC added 3 commits April 24, 2026 22:17
- Add hash-based ID mapping (splitmix64) as alternative to dense
  torch.unique mapping in build_sequences and align_embeddings.
- Add UniSRecModel.export_to_onnx() for native ONNX export of
  encoder and item embeddings (project_all).
- Add UniSRecModel.map_item_ids() for external→internal ID conversion
  at inference time (works for both dense and hash modes).
- Remove FlatSASRecModel/FlatSASRecLightning (RecTools-coupled wrappers
  that duplicated UniSRecModel functionality).
- Add tests: hash mapping (including string-derived IDs),
  ONNX export roundtrip, map_item_ids for both modes.
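For reference, the splitmix64 finalizer is a handful of integer operations (the constants below are the standard published ones; the bucketing helper is my illustration of how a hash-based mapping could assign internal ids, not necessarily the PR's exact scheme):

```python
MASK64 = (1 << 64) - 1

def splitmix64(x: int) -> int:
    """splitmix64 mixing function: a bijection on 64-bit integers."""
    x = (x + 0x9E3779B97F4A7C15) & MASK64
    x = ((x ^ (x >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    x = ((x ^ (x >> 27)) * 0x94D049BB133111EB) & MASK64
    return (x ^ (x >> 31)) & MASK64

def hash_to_internal(raw_id: int, n_buckets: int) -> int:
    """Hypothetical bucketing: modulo the hash into a fixed-size table,
    keeping 0 reserved for padding. Collisions are possible, which is
    the trade-off against the dense torch.unique mapping."""
    return splitmix64(raw_id) % n_buckets + 1
```

Because every step (add, xorshift, multiply by an odd constant) is invertible modulo 2^64, splitmix64 itself never collides; only the final modulo can.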

Copilot AI left a comment


Pull request overview

Adds a new rectools.fast_transformers subpackage providing GPU-native preprocessing and standalone sequential transformer recommenders (FlatSASRec + UniSRec), plus ranking utilities, scripts, and comprehensive tests.

Changes:

  • Introduces torch-native sequence building (build_sequences), embedding alignment, and lightweight dataset/dataloader helpers.
  • Adds UniSRec (pretrained text embeddings + adaptor + SASRec encoder) with Lightning training wrapper and a standalone UniSRecModel API (fit/checkpoint/ONNX export).
  • Adds rank_topk() for batched scoring with CSR filtering + whitelist, along with benchmark scripts and extensive test coverage.

Reviewed changes

Copilot reviewed 17 out of 19 changed files in this pull request and generated 9 comments.

| File | Description |
|---|---|
| rectools/fast_transformers/__init__.py | Exposes the new fast_transformers public API surface. |
| rectools/fast_transformers/gpu_data.py | Implements torch-native preprocessing utilities (sequence building, embedding alignment, dataloader helpers). |
| rectools/fast_transformers/net.py | Adds FlatSASRec network implementation. |
| rectools/fast_transformers/ranking.py | Adds rank_topk() batching + filtering + whitelist ranking utility. |
| rectools/fast_transformers/unisrec_lightning.py | Adds LightningModule wrapper (loss/optimizer/scheduler dispatch) for UniSRec training phases. |
| rectools/fast_transformers/unisrec_model.py | Adds standalone UniSRecModel (3-phase training, checkpointing, ONNX export, ID mapping). |
| rectools/fast_transformers/unisrec_net.py | Adds UniSRec network (adaptor + transformer encoder + helper methods). |
| tests/fast_transformers/__init__.py | Test package marker for fast_transformers. |
| tests/fast_transformers/test_gpu_data.py | Tests for sequence building, embedding alignment, dataset/dataloader, and hashing. |
| tests/fast_transformers/test_net.py | Tests for FlatSASRec forward paths and encoding helpers. |
| tests/fast_transformers/test_onnx_export.py | Tests ONNX export/roundtrip for UniSRec network and UniSRecModel export. |
| tests/fast_transformers/test_ranking.py | Tests top-k ranking, filtering, whitelist behavior, and edge cases. |
| tests/fast_transformers/test_unisrec_lightning.py | Tests UniSRecLightning configuration + loss/scheduler dispatch behavior. |
| tests/fast_transformers/test_unisrec_model.py | Tests UniSRecModel fit phases, losses/optimizers/schedulers, checkpointing, and mapping. |
| tests/fast_transformers/test_unisrec_net.py | Tests UniSRec network output shapes, adaptor variants, and freeze/unfreeze helpers. |
| scripts/compare_sasrec_unisrec.py | Benchmark script to compare RecTools SASRec vs UniSRec-ID and generate a report. |
| scripts/comparison_report.md | Adds a sample benchmark report output. |
| CHANGELOG.md | Documents the new module and features under Unreleased. |
| .gitignore | Ignores new dev artifacts, model weights, and data folders. |


Comment on lines +30 to +41
def build_sequences(
    user_ids: torch.Tensor,
    item_ids: torch.Tensor,
    timestamps: torch.Tensor,
    max_len: int,
    min_interactions: int = 2,
    device: str = "cuda",
    id_mapping: str = "dense",
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    user_ids = user_ids.to(device)
    item_ids = item_ids.to(device)
    timestamps = timestamps.to(device)
Comment on lines +43 to +49
    unique_items = torch.unique(item_ids)
    n_unique = len(unique_items)

    if id_mapping == "dense":
        _, item_inv = torch.unique(item_ids, return_inverse=True)
        internal_items = item_inv + 1
    elif id_mapping == "hash":
Comment on lines +276 to +307
        x, y, unique_items, unique_users = build_sequences(
            user_ids,
            item_ids,
            timestamps,
            max_len=self.session_max_len,
            min_interactions=self.train_min_user_interactions,
            id_mapping=self.id_mapping,
        )
        self._unique_items = unique_items.cpu()
        self._unique_users = unique_users.cpu()
        n_items = len(unique_items)

        aligned_emb = align_embeddings(self.pretrained_item_embeddings, unique_items, n_items, self.id_mapping)

        net = UniSRec(
            n_items=n_items,
            pretrained_embeddings=aligned_emb,
            n_factors=self.n_factors,
            projection_hidden=self.projection_hidden,
            n_blocks=self.n_blocks,
            n_heads=self.n_heads,
            session_max_len=self.session_max_len,
            dropout=self.dropout,
            adaptor_dropout=self.adaptor_dropout,
            adaptor_type=self.adaptor_type,
            use_adaptor_ffn=self.use_adaptor_ffn,
            ffn_type=self.ffn_type,
            ffn_expansion=self.ffn_expansion,
        )

        train_dl = make_dataloader(x, y, batch_size=self.batch_size, shuffle=True)

Comment on lines +448 to +450
        lookup = {int(v): i + 1 for i, v in enumerate(self._unique_items.tolist())}
        return torch.tensor([lookup.get(int(x), 0) for x in external_ids.tolist()], dtype=torch.long)

Comment on lines +59 to +61
        viewed_mask = torch.tensor(batch_csr.toarray(), dtype=torch.bool, device=device)
        scores[viewed_mask] = -float("inf")

Comment on lines +37 to +46
    def test_padding_invariance(self, net: FlatSASRec) -> None:
        """Different left-padding should produce same last-position embedding."""
        net.eval()
        # Same content should produce identical output
        x_a = torch.tensor([[0, 0, 0, 5, 10]])
        x_b = torch.tensor([[0, 0, 0, 5, 10]])
        with torch.no_grad():
            e_a = net.encode_last(x_a)
            e_b = net.encode_last(x_b)
        torch.testing.assert_close(e_a, e_b)
Comment on lines +107 to +115
class TestPaddingInvariance:
    def test_same_input_same_output(self, net: UniSRec) -> None:
        net.eval()
        x_a = torch.tensor([[0, 0, 0, 5, 10]])
        x_b = torch.tensor([[0, 0, 0, 5, 10]])
        with torch.no_grad():
            e_a = net.encode_last(x_a, use_id=False)
            e_b = net.encode_last(x_b, use_id=False)
        torch.testing.assert_close(e_a, e_b)
Comment on lines +306 to +311
        train_dl = make_dataloader(x, y, batch_size=self.batch_size, shuffle=True)

        val_dl = None
        if self.patience is not None:
            val_y_last = y[:, -1:]
            val_dl = make_dataloader(x, val_y_last, batch_size=self.batch_size, shuffle=False)
Comment on lines +276 to +283
        x, y, unique_items, unique_users = build_sequences(
            user_ids,
            item_ids,
            timestamps,
            max_len=self.session_max_len,
            min_interactions=self.train_min_user_interactions,
            id_mapping=self.id_mapping,
        )
    max_len: int,
    min_interactions: int = 2,
    device: str = "cuda",
    id_mapping: str = "dense",
Collaborator


Better to use Literal for such things

    min_interactions: int = 2,
    device: str = "cuda",
    id_mapping: str = "dense",
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
Collaborator


Please add extensive docstrings for all the public methods, especially those meant to be used stand-alone. It's especially important here since you're returning 4 tensors and the user can't tell what they mean. It would also be good to add examples.

Comment thread .gitignore
catboost_info/

# Dev artifacts
training_folder/
Collaborator


a bit weird name, can we remove it?

Comment thread CHANGELOG.md
- `align_embeddings()` for mapping pretrained embedding matrices to internal item ID order
- `GPUBatchDataset` and `make_dataloader()` — lightweight torch Dataset/DataLoader wrappers for sequence training data
- Configurable FFN blocks in `UniSRec`: `conv1d` (original paper), `linear_gelu`, `linear_relu` with adjustable expansion factor
- Tests for all `fast_transformers` submodules (143 tests)
Collaborator


We normally don't add anything to the changelog that doesn't affect users directly, so there's not much sense in mentioning the tests.

Collaborator


Please put this script and the report into a subfolder of the benchmark folder.

    return aligned


class GPUBatchDataset(TorchDataset):
Collaborator


I'm not sure the name reflects the purpose:

  • why GPU?
  • what does Batch mean?

It also sounds quite "universal" even though I'd say it's more task-specific

    y: torch.Tensor,
    batch_size: int,
    shuffle: bool = True,
    transform: tp.Optional[tp.Callable] = None,
Collaborator


I'd recommend adding **kwargs here to cover the data loader's other parameters.

On the other hand, I'm not sure it makes much sense to wrap 2 function calls in a separate function.

from scipy import sparse


def rank_topk(
Collaborator


Sorry, I'm too lazy to check: could you please describe why we need this given that we have TorchRanker? Could we reuse the code?
